"""
model name : 深度学习
file       : nms.py
information:
    author : OuYang
    time   : 2025/1/18
"""
import torch


def _box_corners(boxes):
    """Convert ``(..., 4)`` boxes from (cx, cy, w, h) to (x1, y1, x2, y2)."""
    cx, cy, w, h = boxes.unbind(-1)
    return torch.stack((cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2), dim=-1)


def _pairwise_iou(boxes_a, boxes_b):
    """Return the ``(A, B)`` IoU matrix between two sets of (cx, cy, w, h) boxes."""
    corners_a = _box_corners(boxes_a)
    corners_b = _box_corners(boxes_b)
    top_left = torch.maximum(corners_a[:, None, :2], corners_b[None, :, :2])
    bottom_right = torch.minimum(corners_a[:, None, 2:], corners_b[None, :, 2:])
    inter = (bottom_right - top_left).clamp(min=0).prod(dim=-1)
    area_a = (corners_a[:, 2:] - corners_a[:, :2]).clamp(min=0).prod(dim=-1)
    area_b = (corners_b[:, 2:] - corners_b[:, :2]).clamp(min=0).prod(dim=-1)
    union = area_a[:, None] + area_b[None, :] - inter
    # Clamp so two zero-area boxes give IoU 0 instead of 0/0 = NaN.
    return inter / union.clamp(min=torch.finfo(inter.dtype).eps)


def _greedy_nms(boxes, scores, iou_threshold):
    """Return indices of the boxes kept by greedy NMS, highest score first."""
    order = torch.argsort(scores, descending=True)
    keep = []
    while order.numel() > 0:
        best, rest = order[0], order[1:]
        keep.append(int(best))
        ious = _pairwise_iou(boxes[best].unsqueeze(0), boxes[rest])[0]
        # Drop every remaining box that overlaps the kept box too much.
        order = rest[ious <= iou_threshold]
    return torch.tensor(keep, dtype=torch.long, device=boxes.device)


def nms(tensors, iou_threshold, nums_classes=1, b=2, score_threshold=0.0):
    """
    Apply per-class non-maximum suppression to raw prediction output.

    The last dimension of ``tensors`` holds ``b`` groups of 5 box values
    followed by ``nums_classes`` class scores (``b * 5 + nums_classes``
    entries in total). Each group of 5 is assumed to be
    ``(x, y, w, h, confidence)``, with ``(x, y)`` the box centre and all boxes
    in one shared coordinate frame -- TODO confirm against the model's
    encoding. If ``(x, y)`` are offsets inside a grid cell (the YOLOv1
    convention), convert them to image coordinates before calling; otherwise
    IoU between boxes from different cells is meaningless.

    Every box takes the highest-scoring class of its cell and is scored as
    ``confidence * class_score``. Boxes scoring ``<= score_threshold`` are
    dropped. The rest go through greedy NMS separately for each image and
    each class: a box is suppressed when its IoU with an already-kept,
    higher-scoring box exceeds ``iou_threshold``.

    :param tensors: prediction tensor of shape (N, ..., b * 5 + nums_classes);
        dimension 0 is the batch, and all middle dimensions (e.g. an S x S
        grid) are flattened into cells
    :param iou_threshold: IoU above which a lower-scoring box is suppressed
    :param nums_classes: number of class scores per cell (>= 1)
    :param b: number of boxes (5 values each) per cell (>= 1)
    :param score_threshold: boxes with score <= this value are discarded first
    :return: float tensor of shape (K, 7), one row per kept detection:
        (batch_index, class_id, score, x, y, w, h), ordered by batch index,
        then class id, then descending score
    :raises ValueError: if ``nums_classes`` or ``b`` is < 1, or the shape of
        ``tensors`` does not match them
    """
    if nums_classes < 1 or b < 1:
        raise ValueError("nums_classes and b must both be >= 1")
    depth = b * 5 + nums_classes
    if tensors.dim() < 2 or tensors.shape[-1] != depth:
        raise ValueError(
            "expected a tensor of shape (N, ..., {}), got {}".format(
                depth, tuple(tensors.shape)))
    if not tensors.is_floating_point():
        tensors = tensors.float()

    batch = tensors.shape[0]
    num_cells = tensors.shape[1:-1].numel()
    cells = tensors.reshape(batch, num_cells, depth)             # (N, cells, depth)
    boxes = cells[..., :b * 5].reshape(batch, num_cells, b, 5)   # (N, cells, b, 5)
    # Reduce over the class scores with dim=-1 rather than a fixed axis.
    class_scores, class_ids = cells[..., b * 5:].max(dim=-1)     # (N, cells)
    # Every box of a cell shares that cell's best class.
    scores = boxes[..., 4] * class_scores.unsqueeze(-1)          # (N, cells, b)
    labels = class_ids.unsqueeze(-1).expand_as(scores)           # (N, cells, b)
    coords = boxes[..., :4]                                      # (N, cells, b, 4)

    detections = []
    for batch_idx in range(batch):
        img_coords = coords[batch_idx].reshape(-1, 4)
        img_scores = scores[batch_idx].reshape(-1)
        img_labels = labels[batch_idx].reshape(-1)
        confident = img_scores > score_threshold
        img_coords = img_coords[confident]
        img_scores = img_scores[confident]
        img_labels = img_labels[confident]
        # unique() is sorted, so the output is grouped by ascending class id.
        for label in img_labels.unique():
            same_class = img_labels == label
            cls_coords = img_coords[same_class]
            cls_scores = img_scores[same_class]
            kept = _greedy_nms(cls_coords, cls_scores, iou_threshold)
            count = kept.numel()
            detections.append(torch.cat((
                cls_scores.new_full((count, 1), batch_idx),
                cls_scores.new_full((count, 1), float(label)),
                cls_scores[kept].unsqueeze(1),
                cls_coords[kept],
            ), dim=1))

    if not detections:
        return tensors.new_zeros((0, 7))
    return torch.cat(detections)


if __name__ == '__main__':
    # Demo input of shape (1, 7, 7, 11): one sample, 7 x 7 positions, and
    # 11 = 2 * 5 + 1 values per position, matching nms's defaults
    # (b=2, nums_classes=1).
    output = torch.tensor([
        [[[0.5167, 0.4930, 0.5121, 0.4949, 0.5272, 0.4896, 0.5086, 0.4784,
           0.4929, 0.5080, 0.4908],
          [0.5061, 0.4977, 0.5103, 0.5057, 0.5116, 0.4972, 0.4913, 0.5058,
           0.5034, 0.4933, 0.5068],
          [0.4811, 0.4996, 0.5117, 0.5063, 0.4998, 0.4921, 0.5022, 0.4958,
           0.4981, 0.4811, 0.4842],
          [0.5018, 0.4871, 0.4897, 0.5125, 0.4925, 0.4934, 0.5053, 0.4921,
           0.4935, 0.5044, 0.4858],
          [0.5004, 0.4905, 0.5061, 0.5107, 0.5152, 0.5074, 0.4942, 0.4774,
           0.4909, 0.5010, 0.5030],
          [0.4888, 0.4893, 0.5101, 0.4897, 0.4868, 0.5111, 0.4982, 0.5012,
           0.5014, 0.5085, 0.5088],
          [0.5123, 0.4959, 0.5148, 0.5000, 0.4911, 0.5130, 0.5118, 0.5074,
           0.4993, 0.5040, 0.5047]],

         [[0.5063, 0.5066, 0.4871, 0.4903, 0.4856, 0.4935, 0.5040, 0.4868,
           0.5073, 0.5146, 0.5070],
          [0.5014, 0.4934, 0.4997, 0.5129, 0.4906, 0.5010, 0.5109, 0.4908,
           0.5007, 0.5087, 0.5009],
          [0.4925, 0.4942, 0.5074, 0.4814, 0.4965, 0.4999, 0.5112, 0.5052,
           0.5192, 0.4907, 0.5031],
          [0.4755, 0.4863, 0.4340, 0.4505, 0.5054, 0.4652, 0.4915, 0.4390,
           0.4446, 0.5096, 0.8230],
          [0.5033, 0.4953, 0.5205, 0.5017, 0.5098, 0.5164, 0.5152, 0.4981,
           0.4844, 0.4927, 0.4958],
          [0.5074, 0.5079, 0.5064, 0.5018, 0.4858, 0.4914, 0.5128, 0.5064,
           0.4912, 0.5053, 0.4811],
          [0.4976, 0.4949, 0.4902, 0.5049, 0.4814, 0.4885, 0.5104, 0.4962,
           0.4945, 0.4896, 0.4962]],

         [[0.5155, 0.5073, 0.4961, 0.5015, 0.4832, 0.4937, 0.4858, 0.5064,
           0.5002, 0.4967, 0.4688],
          [0.5006, 0.4859, 0.4943, 0.4881, 0.5120, 0.5212, 0.5045, 0.5011,
           0.4928, 0.5027, 0.5147],
          [0.5986, 0.6614, 0.2036, 0.4056, 0.5028, 0.5920, 0.6648, 0.2056,
           0.4016, 0.4918, 0.9194],
          [0.7238, 0.3724, 0.4588, 0.5213, 0.5014, 0.7358, 0.3681, 0.4449,
           0.4961, 0.5113, 0.9172],
          [0.5042, 0.4971, 0.4983, 0.5110, 0.4962, 0.4897, 0.5221, 0.5085,
           0.5009, 0.4784, 0.4883],
          [0.4780, 0.4868, 0.4928, 0.5140, 0.4869, 0.4936, 0.4813, 0.4923,
           0.5035, 0.4966, 0.5059],
          [0.4947, 0.4949, 0.4885, 0.4940, 0.4961, 0.4882, 0.5014, 0.5065,
           0.5039, 0.4956, 0.4885]],

         [[0.4922, 0.5091, 0.5126, 0.4960, 0.4960, 0.5104, 0.4956, 0.4887,
           0.5016, 0.5105, 0.4883],
          [0.3793, 0.5672, 0.7104, 0.4564, 0.5040, 0.3700, 0.5531, 0.7046,
           0.4660, 0.5188, 0.8072],
          [0.4591, 0.7644, 0.1928, 0.3635, 0.5033, 0.4713, 0.7596, 0.2050,
           0.3474, 0.5110, 0.9582],
          [0.4474, 0.2115, 0.2174, 0.2691, 0.4874, 0.4531, 0.2138, 0.2080,
           0.2619, 0.5160, 0.9700],
          [0.4928, 0.6325, 0.2566, 0.4613, 0.5145, 0.4756, 0.6228, 0.2516,
           0.4638, 0.5026, 0.8687],
          [0.5183, 0.5090, 0.5122, 0.4943, 0.4940, 0.4991, 0.5175, 0.5160,
           0.5061, 0.5121, 0.4975],
          [0.4821, 0.4951, 0.5053, 0.5059, 0.5104, 0.4935, 0.5012, 0.5035,
           0.5009, 0.4950, 0.4849]],

         [[0.4917, 0.4941, 0.4871, 0.5114, 0.4756, 0.5038, 0.4758, 0.5013,
           0.5030, 0.4739, 0.5152],
          [0.5066, 0.4997, 0.5105, 0.5078, 0.4911, 0.4933, 0.4937, 0.4980,
           0.5202, 0.5026, 0.4891],
          [0.2946, 0.5920, 0.2611, 0.4100, 0.5015, 0.2860, 0.5882, 0.2677,
           0.4095, 0.4877, 0.9150],
          [0.2109, 0.5187, 0.2888, 0.3935, 0.5091, 0.2087, 0.5200, 0.2809,
           0.3874, 0.4981, 0.9248],
          [0.4930, 0.5036, 0.5081, 0.5034, 0.4989, 0.5176, 0.5120, 0.5003,
           0.4995, 0.5056, 0.5043],
          [0.4790, 0.5232, 0.5045, 0.4996, 0.4901, 0.5250, 0.4926, 0.4816,
           0.5117, 0.4863, 0.5050],
          [0.4979, 0.4922, 0.5057, 0.5082, 0.5083, 0.5038, 0.5007, 0.5093,
           0.5054, 0.5088, 0.5215]],

         [[0.5037, 0.4936, 0.5187, 0.5100, 0.5057, 0.4974, 0.5170, 0.5011,
           0.5079, 0.5070, 0.4891],
          [0.5085, 0.4980, 0.4887, 0.5007, 0.5037, 0.5018, 0.4880, 0.4829,
           0.4997, 0.5296, 0.5072],
          [0.5100, 0.5189, 0.4746, 0.4974, 0.4968, 0.5060, 0.5096, 0.4854,
           0.4973, 0.4814, 0.5232],
          [0.4979, 0.5089, 0.4911, 0.5011, 0.4947, 0.5136, 0.5062, 0.4934,
           0.4886, 0.4935, 0.4989],
          [0.4956, 0.5010, 0.5021, 0.5025, 0.4873, 0.5001, 0.4756, 0.4953,
           0.5039, 0.5087, 0.5030],
          [0.4915, 0.5070, 0.4992, 0.4991, 0.4869, 0.4907, 0.4843, 0.4920,
           0.4982, 0.4980, 0.4906],
          [0.5158, 0.4940, 0.5036, 0.4948, 0.4950, 0.5128, 0.4947, 0.5079,
           0.4955, 0.4994, 0.5074]],

         [[0.4985, 0.5099, 0.5066, 0.5153, 0.4834, 0.5229, 0.4913, 0.4905,
           0.4876, 0.5031, 0.5116],
          [0.5061, 0.5069, 0.5137, 0.4803, 0.5077, 0.4942, 0.5105, 0.5076,
           0.5016, 0.4906, 0.4793],
          [0.4895, 0.5230, 0.5102, 0.5009, 0.5068, 0.4754, 0.4939, 0.4957,
           0.4934, 0.5159, 0.5003],
          [0.5728, 0.6701, 0.1895, 0.3564, 0.4856, 0.5778, 0.6763, 0.2106,
           0.3470, 0.5154, 0.8216],
          [0.5153, 0.5110, 0.4971, 0.4891, 0.5087, 0.5016, 0.5165, 0.4838,
           0.5102, 0.5133, 0.4907],
          [0.4937, 0.5095, 0.4936, 0.5062, 0.5072, 0.5213, 0.4921, 0.4986,
           0.4789, 0.5164, 0.5177],
          [0.4941, 0.5148, 0.5033, 0.4901, 0.5135, 0.5205, 0.5136, 0.5182,
           0.4977, 0.4731, 0.4972]]]
    ])
    # Show what nms returns instead of discarding it.
    detections = nms(output, iou_threshold=0.2)
    print(detections)
